import torch
import torch.nn as nn
import torch.nn.functional as F


class DenoisingGenerator(nn.Module):
    """
    An improved symmetric U-Net generator for 64x64 images.
    - Replaces ConvTranspose2d with Upsample + Conv2d to reduce artifacts.
    - Uses LeakyReLU in the decoder for better gradient flow.

    The encoder halves the spatial size six times (five stages plus the
    bottleneck). The noise vector ``eta`` is concatenated at the bottleneck.
    The decoder then doubles the size six times, concatenating the matching
    encoder output (skip connection) before each of its first five stages.

    With the default 64x64 input the bottleneck is 1x1. Any input whose height
    and width are multiples of 64 also works; ``eta`` is then broadcast over
    the larger bottleneck grid.

    Args:
        eta_dim: Number of noise values per sample in ``eta``.
        in_channels: Channels of the input image ``x``.
        out_channels: Channels of the generated image.
        base_channels: Width of the first encoder stage; deeper stages use
            2x, 4x, 8x, 16x and 32x this value. The default (64) reproduces the
            64-128-256-512-1024-2048 layout and the same ``state_dict`` keys.

    Note:
        For 64x64 inputs the bottleneck BatchNorm normalizes a 1x1 map, so in
        training mode every batch must contain more than one sample
        (``nn.BatchNorm2d`` raises otherwise).
    """
    def __init__(self, eta_dim=100, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c = base_channels

        # --- Encoder (Downsampling Path) ---
        # Shape comments assume the default (3, 64, 64) input and base_channels=64.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, c, kernel_size=4, stride=2, padding=1, bias=False),  # (64, 32, 32)
        )
        self.enc2 = self._make_encoder_block(c, c * 2)        # (128, 16, 16)
        self.enc3 = self._make_encoder_block(c * 2, c * 4)    # (256, 8, 8)
        self.enc4 = self._make_encoder_block(c * 4, c * 8)    # (512, 4, 4)
        self.enc5 = self._make_encoder_block(c * 8, c * 16)   # (1024, 2, 2)

        # --- Bottleneck ---
        self.bottleneck = self._make_encoder_block(c * 16, c * 32)  # (2048, 1, 1)

        # --- Decoder (Upsampling Path) using Upsample + Conv ---
        # Each block's input = previous decoder output + matching skip tensor
        # (dec1 instead receives the bottleneck plus eta).
        self.dec1 = self._make_decoder_block(c * 32 + eta_dim, c * 16)
        self.dec2 = self._make_decoder_block(c * 16 + c * 16, c * 8)
        self.dec3 = self._make_decoder_block(c * 8 + c * 8, c * 4)
        self.dec4 = self._make_decoder_block(c * 4 + c * 4, c * 2)
        self.dec5 = self._make_decoder_block(c * 2 + c * 2, c)
        self.dec6 = self._make_decoder_block(c + c, c)

        # --- Final Output Layer ---
        self.final_conv = nn.Sequential(
            # In-place is safe here: dec6's output is not reused anywhere else.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(c, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh(),  # Tanh maps output to [-1, 1].
        )

    def _make_encoder_block(self, in_channels, out_channels):
        """Return LeakyReLU -> stride-2 Conv2d -> BatchNorm2d, halving H and W."""
        return nn.Sequential(
            # NOT in-place: this block's input is also used later as a skip
            # connection. An in-place activation would overwrite that tensor,
            # and the decoder's own LeakyReLU would then be applied to it twice.
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def _make_decoder_block(self, in_channels, out_channels):
        """Return LeakyReLU -> 2x nearest Upsample -> 3x3 Conv2d -> BatchNorm2d."""
        return nn.Sequential(
            # In-place is safe: the input is always a fresh torch.cat result.
            nn.LeakyReLU(0.2, inplace=True),
            # 1. Upsample
            nn.Upsample(scale_factor=2, mode='nearest'),
            # 2. Convolve
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
        """
        Generate an image from ``x`` conditioned on noise ``eta``.

        Args:
            x: Batch of shape (B, in_channels, H, W); H and W should be
                multiples of 64 so every skip connection lines up.
            eta: Noise with batch size B as its first dimension and
                ``eta_dim`` elements per sample, e.g. (B, eta_dim) or
                (B, eta_dim, 1, 1). Non-contiguous tensors are accepted.

        Returns:
            Tensor of shape (B, out_channels, H, W) with values in [-1, 1].
        """
        # --- Encoder Path (shapes for the default 64x64 input) ---
        enc1_out = self.enc1(x)         # (64, 32, 32)
        enc2_out = self.enc2(enc1_out)  # (128, 16, 16)
        enc3_out = self.enc3(enc2_out)  # (256, 8, 8)
        enc4_out = self.enc4(enc3_out)  # (512, 4, 4)
        enc5_out = self.enc5(enc4_out)  # (1024, 2, 2)

        # --- Bottleneck ---
        bottleneck_out = self.bottleneck(enc5_out)  # (2048, 1, 1)

        # Flatten eta per sample and broadcast it over the bottleneck grid
        # (1x1 for 64x64 inputs, where the expand is a no-op).
        eta_map = eta.reshape(eta.size(0), -1, 1, 1).expand(
            -1, -1, bottleneck_out.size(2), bottleneck_out.size(3)
        )
        bottleneck_noisy = torch.cat([bottleneck_out, eta_map], dim=1)

        # --- Decoder Path with Skip Connections ---
        dec1_out = self.dec1(bottleneck_noisy)
        dec2_out = self.dec2(torch.cat([dec1_out, enc5_out], dim=1))
        dec3_out = self.dec3(torch.cat([dec2_out, enc4_out], dim=1))
        dec4_out = self.dec4(torch.cat([dec3_out, enc3_out], dim=1))
        dec5_out = self.dec5(torch.cat([dec4_out, enc2_out], dim=1))
        dec6_out = self.dec6(torch.cat([dec5_out, enc1_out], dim=1))

        # --- Final Layer ---
        return self.final_conv(dec6_out)